__Author__ = "Peter Herman"
__Project__ = "internet_and_trade"
__Created__ = "Sept 20, 2022"
__Description__ = '''This script performs a simulation of internet connectivity using revised estimates (terms split 
into income levels) and the Eora trade data.'''


import pandas as pd
import gegravity as ge
import gme as gme
from warnings import warn
import re as regex
from math import ceil, floor

# ----
# Specification
# ----
# Root folder holding all inputs and outputs. A raw string avoids the invalid escape sequences
# (\w, \P, \p, \i) of a plain literal; the trailing separator is appended separately because a
# raw string cannot end in a single backslash. The resulting value is unchanged.
root_directory = r"D:\work\Peter_Herman\projects\internet_and_trade_public_files" + "\\"

# Define location of data inputs
# For X values, do not need services specific data as eora services trade will be used instead. Aggregate trade X values have fewer missing cost variable observations
estimation_data_local = "{}all_xvalues_aggregate.csv".format(root_directory)
fe_est_local = "{}all_countrypairFE_services.csv".format(root_directory)
betas_local = "{}beta_ests_services.csv".format(root_directory)
internet_local = "{}WB_internet.csv".format(root_directory)
eora_trade_local = "{}eora_trade_2016.csv".format(root_directory)
country_names_local = "{}country_year_coverage_v2_0.csv".format(root_directory)

# Name of model to use in model outputs
version = 'india_japan'
# Base year to use
base_year = 2016

# Column name to use as trade values from Eora dataset
trade_var = 'eora_trade_services_$M'

# Location to save model results
save_path = "{}".format(root_directory)



# High and upper-middle income level countries (ISO3 codes). Bulgaria is listed under its ISO3
# code BGR (previously mistyped as "BLG"), and the duplicate "PLW" entry was removed.
# NOTE(review): "KSV" is kept for Kosovo; the World Bank code is "XKX" - confirm which code the
# model data use.
high_income_cntry = ["AND", "ATG", "ABW", "AUS", "AUT", "BHS", "BHR", "BRB", "BEL", "BMU", "VGB", "BRN",
                     "CAN", "CYM", "CHL", "HRV", "CUW", "CYP", "CZE", "DNK", "EST", "FRO", "FIN", "FRA", "PYF", "DEU",
                     "GIB", "GRC", "GRL", "GUM", "HKG", "HUN", "ISL", "IRL", "IMN", "ISR", "ITA", "JPN", "KOR", "KWT",
                     "LVA", "LIE", "LTU", "LUX", "MAC", "MLT", "MCO", "NRU", "NLD", "NCL", "NZL", "MNP", "NOR", "OMN",
                     "PLW", "POL", "PRT", "PRI", "QAT", "SMR", "SAU", "SYC", "SGP", "SXM", "SVK", "SVN", "ESP", "KNA",
                     "MAF", "SWE", "CHE", "TWN", "TTO", "TCA", "ARE", "GBR", "USA", "URY", "VIR", "ALB", "FJI", "NAM",
                     "ASM", "GAB", "MKD", "ARG", "GEO", "ARM", "GRD", "PRY", "AZE", "GTM", "PER", "BLR", "GUY",
                     "RUS", "BLZ", "IRQ", "SRB", "BIH", "JAM", "ZAF", "BWA", "JOR", "LCA", "BRA", "KAZ", "VCT", "BGR",
                     "KSV", "SUR", "CHN", "LBY", "THA", "COL", "MYS", "TON", "CRI", "MDV", "TUR", "CUB", "MHL", "TKM",
                     "DMA", "MUS", "TUV", "DOM", "MEX", "GNQ", "MDA", "ECU", "MNE", "ROU", "PAN"]

# ----
# Determine possible country coverage
# ----


# Load pair fixed effect estimates
fe_ests = pd.read_csv(fe_est_local)

# Set some regression names
regression_names = ['connect_provision']
reg_spec = 'connect_provision'
# Prep pair fixed effect estimates: keep the first two column names (referenced below as 'exporter'
# and 'importer') and rename the estimate column to the regression name.
# NOTE(review): assumes the CSV has exactly three columns - confirm against the input file.
fe_ests.columns = [fe_ests.columns[0], fe_ests.columns[1]] + regression_names
fe_ests['var'] = fe_ests['exporter'] + '_' + fe_ests['importer']

fe_pare_down = fe_ests.copy()
# Create set of unique country IDs from fixed effects
fe_countries = set(fe_ests['importer'].unique()).union(fe_ests['exporter'].unique())

# --- Determine pair FE coverage, which is the limiting factor for country coverage. Done by iteratively eliminating the countries
#     with the worst coverage until a square dataset of exporter-importer fixed effects can be attained.
max_pare_iterations = 10000
for _ in range(max_pare_iterations):
    if fe_pare_down.shape[0] == len(fe_countries) ** 2:
        # Set of fixed effects is balanced (one FE per exporter-importer pair); stop paring
        break
    # Count fixed effects for each country as an exporter (var_x) and as an importer (var_y)
    fe_exp = fe_pare_down.groupby('exporter')['var'].count().rename('var_x')
    fe_imp = fe_pare_down.groupby('importer')['var'].count().rename('var_y')
    # Align both counts on the country. A country present on only one side gets 0 on the other,
    # so it is among the first dropped; a NaN count would never equal the minimum, keeping that
    # country forever and preventing the panel from ever becoming square.
    fe_availability = pd.concat([fe_exp, fe_imp], axis=1).fillna(0)
    # Determine minimum number of fixed effects available
    min_available = fe_availability.min().min()
    # Keep only countries that have neither the minimum exporter nor importer count
    keep = (fe_availability['var_x'] != min_available) & (fe_availability['var_y'] != min_available)
    fe_countries = set(fe_availability.loc[keep].index)
    # Pare down fixed effect estimates to include only the retained countries
    fe_pare_down = fe_pare_down.loc[fe_pare_down['exporter'].isin(fe_countries)
                                    & fe_pare_down['importer'].isin(fe_countries), :]
else:
    # Iterations exhausted without the balance check breaking; fail loudly rather than continue unbalanced
    if fe_pare_down.shape[0] != len(fe_countries) ** 2:
        raise RuntimeError("Pair fixed effects could not be balanced within {} iterations.".format(max_pare_iterations))
if not fe_countries:
    raise ValueError("No countries remain after balancing the pair fixed effects.")
# Once data is square, use it
fe_ests = fe_pare_down




# ----
# Prep Input Data
# ----
# Read the pair-level covariate file used for the econometric estimations
raw_input_data = pd.read_csv(estimation_data_local)

# Gather every row whose importer-exporter-year key repeats (kept for inspection)
pair_year_key = ['importer', 'exporter', 'year']
dups = raw_input_data[raw_input_data.duplicated(pair_year_key, keep=False)].copy()
dups = dups.sort_values(pair_year_key)

# Restrict to the base year and label the ITPD trade column explicitly
input_data = (raw_input_data[raw_input_data['year'] == base_year]
              .copy()
              .rename(columns={'trade': 'itpd_trade_total'}))

data_summary = input_data.describe()

# Eora trade flows (reported in $ millions) and the exporter codes each source covers
eora_trade = pd.read_csv(eora_trade_local)
eora_codes = eora_trade['exporter'].unique().tolist()
itpd_codes = input_data['exporter'].unique().tolist()


# Attach Eora trade to the covariates; each exporter-importer pair must match at most once
input_data = input_data.merge(eora_trade, how='left', on=['exporter', 'importer'], validate='1:1')

# Model countries are those that survived the fixed-effect balancing
top_countries = fe_countries

# Every ordered exporter-importer combination of model countries (domestic pairs included)
top_country_pairs = pd.DataFrame(
    [(origin, destination) for origin in top_countries for destination in top_countries],
    columns=['exporter', 'importer'])
sub_panel = top_country_pairs.merge(input_data, how='left', on=['exporter', 'importer'], validate='1:1')

# Fraction of base-year trade captured by the model countries
total_trade = input_data[trade_var].sum()
sub_trade = sub_panel[trade_var].sum()
trade_covered = sub_trade / total_trade

# Any missing value anywhere in the square panel is fatal
sub_panel_missing = sub_panel[sub_panel.isna().any(axis=1)]
if not sub_panel_missing.empty:
    raise ValueError("Missing some input data")


# Create output and expenditure as total trade by importer / exporter within the square panel.
# The string aggregator 'sum' is used because passing the builtin sum to groupby().agg is
# deprecated in pandas >= 2.1 (it emits a FutureWarning).
expenditures = sub_panel.groupby('importer').agg({trade_var: 'sum'}).reset_index()
expenditures.columns = ['importer', 'expenditure']
outputs = sub_panel.groupby('exporter').agg({trade_var: 'sum'}).reset_index()
outputs.columns = ['exporter', 'output']

# Add outputs and expenditures to input panel
sub_panel = sub_panel.merge(expenditures, how='outer', on='importer', validate='m:1')
sub_panel = sub_panel.merge(outputs, how='outer', on='exporter', validate='m:1')

# Check for any sub sample observations that weren't used in estimation: the '_est*' columns are
# presumably 0/1 estimation-sample indicators, so a minimum below 1 means some rows were omitted.
sum_info = sub_panel.describe()
est_indicators = [col for col in sum_info.columns if col.startswith('_est')]
for col in est_indicators:
    if sum_info.loc['min', col] < 1:
        warn("{} has some omitted rows".format(col))

# ----
# Prep Coefficient Estimates
# ----

# Load other coefficient estimates
beta_ests = pd.read_csv(betas_local)
beta_ests.columns = ['var'] + regression_names

# Combine cost vars into a single dataframe
select_cols = ['var', reg_spec]
params = pd.concat([beta_ests[select_cols], fe_ests[select_cols]], axis=0)
# Drop vars without parameter ests
params = params.loc[~params[reg_spec].isna(), :]

cost_object = ge.CostCoeffs(params, identifier_col='var', coeff_col=reg_spec)

# Get a list of cost variables to use
cost_vars = params['var'].tolist()
# Remove all border-year fixed effects then add back in the baseline year's term.
# Raw string: '\d' in a plain literal is an invalid escape (SyntaxWarning on Python 3.12+).
cost_vars = [var for var in cost_vars if not regex.match(r'foreign_\d\d\d\d', var)]
# Drop Constant as pair fixed effects were not estimated with constant
cost_vars = [var for var in cost_vars if var != '_cons']
cost_vars.append('foreign_{}'.format(base_year))

# Check that all pair fixed effects exist. Membership is tested against a set: there are roughly
# n^2 pair FEs and n^2 panel rows, so list lookups would be quadratic in that count.
needed_fes = (sub_panel['exporter'] + '_' + sub_panel['importer']).tolist()
available_vars = set(cost_vars)
missing = [fe for fe in needed_fes if fe not in available_vars]
if missing:
    raise ValueError('There are country pairs without matching pair fixed effects.')

# ---
# Prep Cost Data
# ---

# Rebuild pair fixed-effect dummies: one indicator column per 'var' value, named exactly by that value
fe_dummies = pd.get_dummies(fe_ests[['exporter', 'importer', 'var']], columns=['var'], prefix='', prefix_sep='')

# Attach the dummies to the panel, recording each row's origin in the '_merge' column
use_panel = sub_panel.merge(fe_dummies, how='outer', on=['exporter', 'importer'], validate='1:1', indicator=True)

# A perfect fit means every row came from both sides of the merge
fe_validator = use_panel['_merge'].unique()
if any(merge_source != 'both' for merge_source in fe_validator):
    raise ValueError('Fixed effect merge with input data not a perfect fit.')



# ----
# Prep GE inputs
# ----
data_summary = use_panel.describe().transpose()

# Create GME data object
gme_data = gme.EstimationData(use_panel, imp_var_name='importer', exp_var_name='exporter', year_var_name='year',
                              trade_var_name=trade_var)
# Create GME gravity model. The left-hand-side variable is the same trade column declared for the
# data object: the original 'trade' column was renamed to 'itpd_trade_total' during data prep,
# so it no longer exists in use_panel.
gme_model = gme.EstimationModel(gme_data, lhs_var=trade_var, rhs_var=cost_vars)





# -----
# Create Provision model
# -----

# Define GE model. The model year is derived from base_year (as a string, matching the original
# literal) so it cannot drift from the year used to select the input data.
provision_model = ge.OneSectorGE(gme_model, year=str(base_year), reference_importer='DEU', expend_var_name='expenditure',
                                 output_var_name='output', sigma=7, cost_variables=cost_vars, cost_coeff_values=cost_object)

# Optional diagnostic - check for candidate OMR rescale values:
# omr_eval = provision_model.check_omr_rescale()

# Solve for baseline model
provision_model.build_baseline(omr_rescale=1)


# --- Define Japan-India experiment ----
# Specify the two experiment countries
jpn_ind = ["JPN", "IND"]

# Create copy of the baseline data
provision_exper = provision_model.baseline_data.copy()


# --- Define Experiment ---

# Rows receiving the shock: exports from 'JPN' (high income) to 'IND' (lower income) where there is
# an active trade agreement (agree_pta == 1)
shock_rows = (provision_exper['exporter'].isin(['JPN'])
              & provision_exper['importer'].isin(['IND'])
              & (provision_exper['agree_pta'] == 1))
# With no matching rows the counterfactual would silently equal the baseline, so flag it
if not shock_rows.any():
    warn("No JPN->IND rows with agree_pta == 1 found; the experiment is identical to the baseline.")
# Set provision index = 1 for the selected rows
provision_exper.loc[shock_rows, 'high_digital_ex'] = 1

# Supply counterfactual data to the GE model
provision_model.define_experiment(provision_exper)



# --- Run Counterfactual ---

# Solve the model under the experiment data supplied above
provision_model.simulate()

# Country-level outcomes, trade-weighted cost shocks, and results expressed in levels
provision_results = provision_model.country_results
provision_shock = provision_model.trade_weighted_shock()
provision_levels = provision_model.calculate_levels()
# Export change in levels = experiment minus baseline observed foreign exports
provision_levels['export_change'] = provision_levels['experiment observed foreign exports'].sub(
    provision_levels['baseline observed foreign exports'])



# ---
# Format and export tables
# ---

# Country names used in place of ISO codes, with a few historical/alternate spellings normalized
country_names = pd.read_csv(country_names_local).replace({
    'Germany; West Germany': 'Germany',
    'Czech Republic': 'Czechia',
    'Egypt, Arab Rep.': 'Egypt',
    'Korea, South': 'South Korea',
    'Ceylon; Sri Lanka': 'Sri Lanka',
    'Malaya; Malaysia': 'Malaysia',
    'South Vietnam; Vietnam': 'Vietnam',
    "Cambodia; Khmer Republic; Kampuchea": "Cambodia",
})

# Define a function to reformat a table of simulation results
def create_table(country_results, wide=True, path: "str | None" = None, round: int = 2):
    """Reformat country-level simulation results into an alphabetized, country-named table.

    Parameters
    ----------
    country_results : pandas.DataFrame
        Results indexed by ISO3 code, containing the columns
        'foreign exports change (percent)' and 'GDP change (percent)'.
    wide : bool
        If True, split the alphabetized table into two side-by-side halves separated by a
        blank spacer column.
    path : str, optional
        If given, write the table there with '&' separators and LaTeX row terminators.
    round : int
        Decimal places used for floats when writing to ``path``. (Shadows the builtin
        ``round`` inside this function; the name is kept for caller compatibility.)

    Returns
    -------
    pandas.DataFrame
        The formatted table.

    Notes
    -----
    Country names come from the module-level ``country_names`` frame, which must provide
    'country', 'iso3', 'earliest_year' and 'latest_year' columns.
    """
    # Select outcomes to report
    results_subset = country_results[['foreign exports change (percent)', 'GDP change (percent)']]
    # Replace ISOs (the index of country_results) with country names
    table = country_names.merge(results_subset, how='right', right_index=True, left_on='iso3')
    table = table.drop(['earliest_year', 'latest_year', 'iso3'], axis=1)
    # Re-alphabetize
    table = table.sort_values('country')
    if wide:
        # Split on the merged table's own length: the name lookup can add rows (duplicate iso3
        # entries), and splitting on len(country_results) would then drop rows from the right half.
        n_rows = table.shape[0]
        n_left = ceil(n_rows / 2)
        # drop=True discards the old index whatever its name, instead of adding an 'index' column
        left = table.head(n_left).reset_index(drop=True)
        right = table.tail(n_rows - n_left).reset_index(drop=True)
        spacer = pd.DataFrame([''] * n_left)
        table_out = pd.concat([left, spacer, right], axis=1)
    else:
        table_out = table

    if path:
        # Write to file if path is supplied using LaTeX notation
        table_out.to_csv(path, sep='&', lineterminator='\\\\\n', index=False,
                         float_format='%.{}f'.format(round))
    return table_out


# Full JPN-IND results table (two side-by-side halves), written as LaTeX rows
create_table(provision_results, wide=True, path=save_path + 'jpn-ind_results_full.tex', round=3)

# Narrow version of the same table; keep only the Japan and India rows
provision_short = create_table(provision_results, wide=False, round=3)
focus_names = ["Japan", "India"]
provision_short = provision_short[provision_short['country'].isin(focus_names)]

# Every other country is treated as Rest of World and summarized as "mean (std)"
reported_cols = ['foreign exports change (percent)', 'GDP change (percent)']
provisions_other = provision_results.loc[~provision_results.index.isin(['IND', 'JPN']), reported_cols].describe()
exports_col, gdp_col = reported_cols
row_exports = "{} ({})".format(round(provisions_other.loc['mean', exports_col], 3),
                               round(provisions_other.loc['std', exports_col], 3))
row_tot = "{} ({})".format(round(provisions_other.loc['mean', gdp_col], 3),
                           round(provisions_other.loc['std', gdp_col], 3))
row_addition = pd.DataFrame([('Rest of World', row_exports, row_tot)], columns=provision_short.columns)


# Alphabetize and round the focus rows, append Rest of World, and save as LaTeX rows
provision_short = provision_short.sort_values('country').round(3)
provision_short = pd.concat([provision_short, row_addition], axis=0)
provision_short.to_csv(save_path + "jpn-ind_results_subset.tex", sep='&', lineterminator='\\\\\n', index=False,
                       float_format='%.3f')


# -----
# Post Analysis for Paper discussion
# -----

# Create list of lower income countries in the model (model countries not in high_income_cntry)
low_income_cntry = [cntry for cntry in top_countries if cntry not in high_income_cntry]

# Create lists of high and low internet-use countries relative to the base-year median
internet_col = 'Individuals using the Internet (% of population) [IT.NET.USER.ZS]'
internet_data = pd.read_csv(internet_local)
internet_data = internet_data.loc[internet_data['Time'] == base_year, :].copy()
# Coerce to numeric so non-numeric missing markers (e.g. '..' in World Bank exports - confirm)
# become NaN instead of breaking describe()/the comparison; a no-op for a numeric column.
internet_data[internet_col] = pd.to_numeric(internet_data[internet_col], errors='coerce')
internet_stats = internet_data[internet_col].describe()
high_internet = internet_data.loc[internet_data[internet_col] > internet_stats['50%'], 'Country Code'].tolist()
high_internet.sort()
low_int_model = [cntry for cntry in top_countries if cntry not in high_internet]

# Country-level results in levels, plus the export change (experiment minus baseline observed exports)
prov_levels = provision_model.calculate_levels()
prov_levels['export_change'] = (prov_levels['experiment observed foreign exports']
                                - prov_levels['baseline observed foreign exports'])

# Bilateral results in levels, sorted ascending by the observed-level trade change
prov_bilat = provision_model.calculate_levels('bilateral').sort_values('trade change (observed level)')

# Flows where both exporter and importer are Japan or India (domestic flows included if present)
focus_pair = ['JPN', 'IND']
ind_jpn_trade = prov_bilat[prov_bilat['exporter'].isin(focus_pair) & prov_bilat['importer'].isin(focus_pair)]


# Print India and Japan GDP changes in $M (baseline GDP times the percent change / 100).
# Only a missing country is tolerated; any other error now surfaces instead of being swallowed.
for iso in jpn_ind:
    try:
        cntry_result = provision_model.country_set[iso]
    except KeyError:
        warn("{} is not in the model; GDP change not reported".format(iso))
        continue
    print("{} GDP change $M: {}".format(iso, cntry_result.baseline_gdp * cntry_result.gdp_change / 100))


# GDP in levels for every model country: counterfactual = baseline * (1 + percent change / 100)
gdp_rows = []
for country_iso, country_result in provision_model.country_set.items():
    baseline = country_result.baseline_gdp
    counterfactual = baseline * (1 + country_result.gdp_change / 100)
    gdp_rows.append((country_iso, baseline, counterfactual, counterfactual - baseline))
gdp_outcomes = (pd.DataFrame(gdp_rows, columns=['country', 'baseline_gdp', 'counter_gdp', 'gdp_difference'])
                .sort_values('gdp_difference'))


# Define function to compute GDP changes by country or group
def compute_group_gdp(members, results, label, excluded=("IND", "JPN")):
    """Aggregate modeled GDP for a country group and compute its change in levels and percent.

    Parameters
    ----------
    members : iterable of str
        ISO3 codes making up the group (list, tuple, or any iterable).
    results : pandas.DataFrame
        One row per country with a 'country' column and numeric GDP columns, including
        'baseline_gdp' and 'counter_gdp'; all non-'country' columns are summed.
    label : str
        Name stored under 'country' in the result.
    excluded : iterable of str, optional
        Experiment countries removed from every group except a group consisting of exactly
        one of them, so their own effects are reported separately. Defaults to ("IND", "JPN").

    Returns
    -------
    pandas.Series
        'country' (= label), the summed columns, '$M gdp change' and '% gdp change'.
        '% gdp change' is NaN when the group's baseline GDP sums to zero (e.g. an empty group).
    """
    members = list(members)
    excluded = list(excluded)
    group_results = results.loc[results['country'].isin(members), :]
    # A group that is exactly one excluded country is reported on its own; otherwise drop them
    is_single_focus_country = len(members) == 1 and members[0] in excluded
    if not is_single_focus_country:
        group_results = group_results.loc[~group_results['country'].isin(excluded), :]
    # Sum modeled baseline and counterfactual GDP in levels over the group (strings not summed)
    totals = group_results.drop(columns='country').sum()
    agg_results = pd.Series({'country': label, **totals.to_dict()}, dtype=object)
    # Compute change in levels
    gdp_change = agg_results['counter_gdp'] - agg_results['baseline_gdp']
    agg_results['$M gdp change'] = gdp_change
    # Compute % change, guarding against a zero baseline
    baseline = agg_results['baseline_gdp']
    agg_results['% gdp change'] = 100 * gdp_change / baseline if baseline != 0 else float('nan')
    return agg_results


# GDP effects for India and Japan
ind_gdp = compute_group_gdp(["IND"], gdp_outcomes, 'India')
jpn_gdp = compute_group_gdp(["JPN"], gdp_outcomes, 'Japan')

# Specific countries or regions (EU, USA, and Rest of World)
EU = use_panel.loc[(use_panel['member_eu_joint'] == 1), 'importer'].unique().tolist()
eu_gdp = compute_group_gdp(EU, gdp_outcomes, 'EU')
usa_gdp = compute_group_gdp(['USA'], gdp_outcomes, 'USA')

# Rest of world = every model country not reported individually or in the EU. The exclusion set is
# built once instead of concatenating EU + [...] anew for every candidate country.
named_countries = set(EU).union(['IND', 'JPN', 'USA'])
row_countries = [iso for iso in top_countries if iso not in named_countries]
row_gdp = compute_group_gdp(row_countries, gdp_outcomes, 'Rest of world')

# Combine GDP ests for all countries/groups (one row per group after transposing)
gdp_combined = pd.concat([ind_gdp, jpn_gdp, eu_gdp, usa_gdp, row_gdp], axis=1)
gdp_combined = gdp_combined.transpose()



# ----
# Trade by country/group
# ----

# Bilateral GE results, joined with observed baseline trade levels from calculate_levels
bilat_levels = provision_model.bilateral_trade_results.reset_index()
bilat_calcualted = provision_model.calculate_levels('bilateral')
observed_baseline = bilat_calcualted[['exporter', 'importer', 'baseline observed trade']]
bilat_levels = bilat_levels.merge(observed_baseline, how='outer', on=['exporter', 'importer'])


# Country/group definitions shared by the importer and exporter passes
trade_groups = [(['JPN'], 'Japan'), (['IND'], 'India'), (['USA'], 'USA'), (EU, 'EU'), (row_countries, 'Rest of world')]
group_sums = {'baseline modeled trade': 'sum', 'experiment trade': 'sum', 'baseline observed trade': 'sum'}

trade_by_group = dict()
for side in ['importer', 'exporter']:
    flow = side[0:6]  # 'import' or 'export'
    pct_col = '% change in {}s'.format(flow)
    usd_col = '$M change in {}s'.format(flow)
    group_frames = []
    for group_isos, group_label in trade_groups:
        # Foreign (non-domestic) flows with a group member on this side of the trade
        in_group = bilat_levels[side].isin(group_isos)
        is_foreign = bilat_levels['importer'] != bilat_levels['exporter']
        foreign_flows = bilat_levels.loc[in_group & is_foreign, :].copy()
        # Label every member with the group name so the flows collapse to one row
        foreign_flows[side] = group_label
        totals = foreign_flows.groupby(side).agg(group_sums).reset_index()
        # Percent change comes from modeled trade; the $M change applies it to observed baseline levels
        totals[pct_col] = 100 * (totals['experiment trade']
                                 - totals['baseline modeled trade']) / totals['baseline modeled trade']
        totals[usd_col] = (totals[pct_col] / 100) * totals['baseline observed trade']
        group_frames.append(totals.rename(columns={side: 'country'}))
    trade_by_group[side] = pd.concat(group_frames, axis=0)

# Merge exporter and importer values and keep the reported columns
all_trade = trade_by_group['exporter'].merge(trade_by_group['importer'], on=['country'])
all_trade = all_trade[['country', '% change in exports',
                       '$M change in exports', '% change in imports', '$M change in imports']]


# Combine group trade values with GDP values
trade_gdp_effects = all_trade.merge(gdp_combined, how='outer', on='country')
report_cols = ['country', '% change in exports', '$M change in exports', '% change in imports',
               '$M change in imports', '% gdp change', '$M gdp change']
tex_trade_gdp = trade_gdp_effects[report_cols].copy()

# Round values for paper tables: percent changes to 4 decimals, $M levels to whole millions
decimals_by_col = {
    '% change in exports': 4,
    '$M change in exports': 0,
    '% change in imports': 4,
    '$M change in imports': 0,
    '% gdp change': 4,
    '$M gdp change': 0,
}
for col, decimals in decimals_by_col.items():
    tex_trade_gdp[col] = tex_trade_gdp[col].astype(float).round(decimals)

# Save to csv and as a TeX formatted table
tex_trade_gdp.to_csv("{}IND-JPN_trade_gdp_impacts_by_group.tex".format(save_path), sep='&', lineterminator='\\\\\n', index=False)
tex_trade_gdp.to_csv("{}IND-JPN_trade_gdp_impacts_by_group.csv".format(save_path), index=False)










